import math
import torch
from torch import nn
from model import embedding
from model.nerf_utils import LinearWithRepeat, MLPWithInputSkips, _xavier_init
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points

class NeuralRadianceField(torch.nn.Module):
    """NeRF-style field that attaches a density and an RGB color to ray samples.

    Colors are predicted from an embedding of the sample points (refined by an
    MLP, ``self.mlp``, unless ``hparams.color.linear`` is set) combined with an
    embedding of the l2-normalized ray directions. Densities are predicted
    either from the same point features or, when ``hparams.density.linear`` is
    set, from a dedicated point embedding followed by one linear layer.
    """

    # Embedding class used when a config section does not name one.
    _DEFAULT_EMBEDDING_CLASS = 'NerfEmbedding'

    def __init__(self, hparams):
        """
        Args:
            hparams: config object accessed both by attribute and via ``.get``.
                Sections read here: ``color``, ``rays``, ``density``; optional
                top-level key ``init_tricks`` (``val_split_size`` is read later
                by ``batched_forward``).

        Raises:
            ValueError: if ``hparams.color.final_activation`` is neither
                ``'sigmoid'`` nor ``'no_activation'``.
            AttributeError: if a configured ``embedding_class`` does not exist
                in ``model.embedding``.
        """
        super().__init__()
        self.hparams = hparams
        self.linear_color = hparams.color.linear
        init_tricks = hparams.get('init_tricks', False)

        color_random_features_args = hparams.color.get('random_features_args', {})
        rays_random_features_args = hparams.rays.get('random_features_args', {})

        # NOTE: construction order below is kept stable on purpose: modules
        # with random initialisation consume the RNG in this order, and the
        # registration order defines the parameter order.
        color_embedding_cls = self._embedding_class(hparams.color.get('embedding_class', None))
        self.color_embedding = color_embedding_cls(
            num_features=hparams.color.num_features,
            dimensions=hparams.color.get('dimensions', {}),
            **color_random_features_args,
        )
        num_features_point_color = self.color_embedding.output_features

        # Ray directions are l2-normalized 3-D vectors in world coordinates
        # (normalisation happens in `_get_colors`).
        rays_embedding_cls = self._embedding_class(hparams.rays.get('embedding_class', None))
        self.rays_embedding = rays_embedding_cls(
            num_features=hparams.rays.num_features,
            dimensions=(1, 1, 1),
            **rays_random_features_args,
        )
        num_features_ray = self.rays_embedding.output_features

        # Output activation of the color head: a sigmoid keeps colors in
        # (0, 1); 'no_activation' leaves them unbounded.
        color_activate = self._make_color_activation(
            hparams.color.get('final_activation', 'sigmoid')
        )

        if not self.linear_color:
            n_hidden_neurons = hparams.color.n_hidden_neurons
            # Point-feature MLP; its output is shared by the color head and
            # (when the density is not linear) the density head.
            self.mlp = MLPWithInputSkips(
                hparams.color.get('mlp_depth', 2),
                num_features_point_color,
                n_hidden_neurons,
                n_hidden_neurons,
                input_skips=hparams.color.get('mlp_input_skips', ()),
                init_tricks=init_tricks,
                activate_final=True,
            )
            # Color head: takes the MLP features together with the ray-direction
            # embedding and predicts a 3-D (RGB) vector per point.
            self.color_layer = torch.nn.Sequential(
                LinearWithRepeat(n_hidden_neurons + num_features_ray, n_hidden_neurons),
                torch.nn.ReLU(True),
                MLPWithInputSkips(
                    hparams.color.get('with_ray_depth', 2) - 1,
                    n_hidden_neurons,
                    3,
                    n_hidden_neurons,
                    input_skips=hparams.color.get('with_ray_input_skips', ()),
                    init_tricks=init_tricks,
                ),
                color_activate,
            )
        else:
            # Linear color model: embedded points feed the color head directly.
            self.mlp = nn.Identity()
            n_hidden_neurons = num_features_point_color
            self.color_layer = torch.nn.Sequential(
                torch.nn.Linear(num_features_point_color + num_features_ray, 3),
                color_activate,
            )

        self._prepare_density_layer(hparams.density, n_hidden_neurons, init_tricks=init_tricks)

    @classmethod
    def _embedding_class(cls, name):
        """Return the embedding class called ``name`` from ``model.embedding``.

        Falls back to ``NerfEmbedding`` when ``name`` is ``None``; raises
        ``AttributeError`` for names that do not exist in the module.
        """
        return getattr(embedding, cls._DEFAULT_EMBEDDING_CLASS if name is None else name)

    @staticmethod
    def _make_color_activation(name):
        """Build the color head's final activation from its config name."""
        if name == 'sigmoid':
            return torch.nn.Sigmoid()
        if name == 'no_activation':
            return nn.Identity()
        raise ValueError(f'Unknown final_activation for color: {name}')

    def _prepare_density_layer(self, hparams, n_hidden_neurons, init_tricks=False):
        """Create the density head (and its own point embedding if linear).

        Args:
            hparams: the ``density`` config section.
            n_hidden_neurons: width of the shared point features, used as the
                input size when the density head reuses them.
            init_tricks: apply Xavier init and a zero bias to the linear layer.
        """
        self.linear_density = hparams.linear

        if self.linear_density:
            density_random_features_args = hparams.get('random_features_args', {})
            density_embedding_cls = self._embedding_class(hparams.get('embedding_class', None))
            self.density_embedding = density_embedding_cls(
                num_features=hparams.num_features,
                dimensions=hparams.dimensions,
                **density_random_features_args,
            )
            density_in_features = self.density_embedding.output_features
        else:
            density_in_features = n_hidden_neurons

        # Softplus keeps raw densities positive.
        self.density_layer = torch.nn.Sequential(
            torch.nn.Linear(density_in_features, 1),
            torch.nn.Softplus(beta=10.0),
        )
        if init_tricks:
            _xavier_init(self.density_layer[0])
            self.density_layer[0].bias.data[0] = 0.0

    def _get_density_features(self, ray_bundle, color_features):
        """Return the features fed to the density head.

        With a linear density model these are a dedicated embedding of the ray
        sample points; otherwise the (MLP-refined) color features are reused.
        """
        if self.linear_density:
            return self.density_embedding(ray_bundle_to_ray_points(ray_bundle))
        return color_features

    def _get_densities(self, features, convert_to_alphas=True):
        """Map per-point ``features`` to densities with ``self.density_layer``.

        ``features`` come from ``_get_density_features``. The layer ends in a
        softplus, so ``raw_densities`` are non-negative; with
        ``convert_to_alphas`` they are mapped to ``1 - exp(-raw_densities)``,
        i.e. into [0, 1).
        """
        raw_densities = self.density_layer(features)
        if convert_to_alphas:
            return 1 - (-raw_densities).exp()
        return raw_densities

    def _get_colors(self, features, rays_directions):
        """Evaluate the color head, returning a 3-D (RGB) vector per point.

        To model view-dependent effects, the embedding of the l2-normalized
        per-ray ``rays_directions`` (world coordinates) is combined with the
        per-point ``features`` before ``self.color_layer`` is applied.
        """
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        rays_embedding = self.rays_embedding(rays_directions_normed)

        if not self.linear_color:
            # The first layer (LinearWithRepeat) receives the pair directly;
            # presumably it broadcasts the per-ray embedding over the samples
            # of each ray — see model.nerf_utils.
            return self.color_layer((features, rays_embedding))

        # Repeat each ray's embedding over that ray's samples so it matches
        # the leading (spatial) shape of `features`, then concatenate.
        rays_embedding_expand = rays_embedding[..., None, :].expand(
            *features.shape[:-1], rays_embedding.shape[-1]
        )
        color_layer_input = torch.cat((features, rays_embedding_expand), dim=-1)
        # nn.Linear and the element-wise activation accept arbitrary leading
        # dimensions, so no flattening is needed.
        return self.color_layer(color_layer_input)

    def _get_point_features(self, ray_bundle: RayBundle):
        """Embed the ray sample points (and refine them with ``self.mlp``).

        Returns a tensor with the leading shape of the ray points (all but the
        last dimension of ``ray_bundle_to_ray_points(ray_bundle)``) and a
        trailing feature dimension.
        """
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        spatial_size = rays_points_world.shape[:-1]
        # The embedding is applied to a flat (num_points, coords) tensor.
        color_features = self.color_embedding(
            rays_points_world.view(-1, rays_points_world.shape[-1])
        )
        color_features = color_features.view(*spatial_size, color_features.shape[-1])

        if not self.linear_color:
            color_features = self.mlp(color_features)
        return color_features

    def _color(self, ray_bundle: RayBundle):
        """Return ``(rays_colors, color_features)`` for the ray sample points."""
        color_features = self._get_point_features(ray_bundle)
        rays_colors = self._get_colors(color_features, ray_bundle.directions)
        return rays_colors, color_features

    def forward(
        self,
        ray_bundle: RayBundle,
        **kwargs,
    ):
        """Attach an RGB color and an opacity to every ray sample point.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A dict with:
                'density': tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                    holding the opacity (alpha in [0, 1)) of each ray point.
                'avg_values': tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                    holding the color of each ray point.
                't_grid': `ray_bundle.lengths`, returned unchanged.
        """
        rays_colors, color_features = self._color(ray_bundle)

        density_features = self._get_density_features(ray_bundle, color_features)
        rays_densities = self._get_densities(density_features)

        return {
            'density': rays_densities,
            'avg_values': rays_colors,
            't_grid': ray_bundle.lengths,
        }

    def batched_forward(
        self,
        ray_bundle: RayBundle,
        split_size: int = 256,
        **kwargs,
    ):
        """Run ``forward`` over the rays in chunks to bound peak memory.

        All leading ray dimensions are flattened, the rays are processed in
        chunks of at most ``hparams.val_split_size`` rays (``split_size`` is
        used only when that key is absent), and the results are reshaped back.
        ``xys`` is not propagated (chunks are built with ``xys=None``).

        Args:
            ray_bundle: RayBundle with origins/directions of shape
                `(*batch_shape, 3)` and lengths of shape
                `(*batch_shape, num_points_per_ray)`.
            split_size: fallback chunk size in rays.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A dict with 'density' and 't_grid' of shape
            `(*batch_shape, num_points_per_ray, 1)` (note the trailing
            singleton on 't_grid', unlike ``forward``) and 'avg_values' of
            shape `(*batch_shape, num_points_per_ray, 3)`.
        """
        batch_shape = ray_bundle.origins.shape[:-1]
        batch_len = batch_shape.numel()
        # A configured `val_split_size` takes precedence over the argument.
        chunk_size = int(self.hparams.get('val_split_size', split_size))

        # reshape (not view) so non-contiguous inputs are accepted; splitting
        # along dim 0 then yields contiguous row-slice views, with no gather.
        flat_origins = ray_bundle.origins.reshape(batch_len, -1)
        flat_directions = ray_bundle.directions.reshape(batch_len, -1)
        flat_lengths = ray_bundle.lengths.reshape(batch_len, -1)

        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=origins,
                    directions=directions,
                    lengths=lengths,
                    xys=None,
                )
            )
            for origins, directions, lengths in zip(
                flat_origins.split(chunk_size),
                flat_directions.split(chunk_size),
                flat_lengths.split(chunk_size),
            )
        ]

        def _gather(key, last_dim):
            # Concatenate one output across chunks and restore the ray shape.
            merged = torch.cat([batch_output[key] for batch_output in batch_outputs], dim=0)
            return merged.view(*batch_shape, -1, last_dim)

        return {
            'density': _gather('density', 1),
            't_grid': _gather('t_grid', 1),
            'avg_values': _gather('avg_values', 3),
        }